/*
 * Copyright 2023 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define UNSUBTYPING_DEBUG 0

#include <cstddef>
#include <iterator>
#include <memory>

#if !UNSUBTYPING_DEBUG
#include <unordered_map>
#include <unordered_set>
#endif

#include "ir/effects.h"
#include "ir/localize.h"
#include "ir/module-utils.h"
#include "ir/names.h"
#include "ir/subtype-exprs.h"
#include "ir/type-updating.h"
#include "ir/utils.h"
#include "pass.h"
#include "support/index.h"
#include "wasm-traversal.h"
#include "wasm-type.h"
#include "wasm.h"

#if UNSUBTYPING_DEBUG
#include "support/insert_ordered.h"
#endif

#if UNSUBTYPING_DEBUG
#define DBG(x) x
#else
#define DBG(x)
#endif

// Compute and use the minimal subtype (and descriptor) relations required to
// maintain module validity and behavior. This minimal relation will be a subset
// of the original subtype (and descriptor) relations. Start by walking the IR
// and collecting pairs of types that need to be in the subtype relation for
// each expression to validate (or require a type to have a descriptor). For
// example, a local.set requires that the type of its operand be a subtype of
// the local's type. Casts do not generate subtypings at this point because it
// is not necessary for the cast target to be a subtype of the cast source for
// the cast to validate.
//
// From that initial subtype relation, we then start finding new subtypings (and
// descriptors) that are required by the subtypings we have found already. These
// transitively required subtypings (and descriptors) come from three sources.
//
// The first source is type definitions. Consider these type definitions:
//
//   (type $A (sub (struct (ref $X))))
//   (type $B (sub $A (struct (ref $Y))))
//
// If we have determined that $B must remain a subtype of $A, then we know that
// $Y must remain a subtype of $X as well, since the type definitions would not
// be valid otherwise. Similarly, knowing that $Y must remain a subtype of $X
// may transitively require other subtypings as well based on their type
// definitions.
//
// The second source of transitive subtyping requirements is casts. Although
// casting from one type to another does not necessarily require that those
// types are related, we do need to make sure that we do not change the
// behavior of casts by removing subtype relationships they might observe. For
// example, consider this module:
//
// (module
//  ;; original subtyping: $bot <: $mid <: $top
//  (type $top (sub (struct)))
//  (type $mid (sub $top (struct)))
//  (type $bot (sub $mid (struct)))
//
//  (func $f
//   (local $top (ref $top))
//   (local $mid (ref $mid))
//
//   ;; Requires $bot <: $top
//   (local.set $top (struct.new $bot))
//
//   ;; Cast $top to $mid
//   (local.set $mid (ref.cast (ref $mid) (local.get $top)))
//  )
// )
//
// The only subtype relation directly required by the IR for this module is $bot
// <: $top. However, if we optimized the module so that $bot <: $top was the
// only subtype relation, we would change the behavior of the cast. In the
// original module, a value of type (ref $bot) is cast to (ref $mid). The cast
// succeeds because in the original module, $bot <: $mid. If we optimize so that
// we have $bot <: $top and no other subtypings, though, the cast will fail
// because the value of type (ref $bot) no longer inhabits (ref $mid). To
// prevent the cast's behavior from changing, we need to ensure that $bot <:
// $mid.
//
// The set of subtyping requirements generated by a cast from $src to $dest is
// that for every known remaining subtype $v of $src, if $v <: $dest in the
// original module, then $v <: $dest in the optimized module. In other words,
// for every type $v of values we know can flow into the cast, if the cast would
// have succeeded for values of type $v before, then we know the cast must
// continue to succeed for values of type $v. These requirements arising from
// casts can also generate transitive requirements because we learn about new
// types of values that can flow into casts as we learn about new subtypes of
// cast sources.
//
// The third source of transitive subtyping requirements is the discovery of
// required descriptors (and vice versa). Subtyping and descriptors combine to
// form this diagram, where rightward arrows mean "described by":
//
//   A -> A.desc
//   ^    ^
//   |    |
//   B -> B.desc
//
// If any three of these types exist in these relations with the others, then
// the validation rules require that the fourth type also exist and be in these
// relations. The only exception is that A.desc is allowed to be missing. This
// complex and recursive relationship between subtyping and descriptor relations
// is why we optimize out unneeded descriptors in this pass rather than e.g.
// GlobalTypeOptimization.
//
// Starting with the initial subtype and descriptor relations determined by
// walking the IR, repeatedly search for new subtypings and descriptors by
// analyzing type definitions and casts until we reach a fixed point. This is
// the minimal subtype/descriptor relation that preserves module validity and
// behavior that can be found without a more precise analysis of types that
// might flow into each cast.

namespace wasm {

namespace {

// Container aliases used throughout this pass. Debug builds use the
// insertion-ordered containers so that iteration follows insertion order
// (presumably to make debug output easier to follow and compare); normal
// builds use the standard unordered containers.
#if UNSUBTYPING_DEBUG
template<typename K, typename V> using Map = InsertOrderedMap<K, V>;
template<typename T> using Set = InsertOrderedSet<T>;
#else
template<typename K, typename V> using Map = std::unordered_map<K, V>;
template<typename T> using Set = std::unordered_set<T>;
#endif

// A tree (or rather a forest) of types with the ability to query and set
// supertypes in constant time and efficiently iterate over supertypes and
// subtypes. Each node can additionally record the type it describes and the
// type that describes it, once those descriptor relations are known to be
// necessary.
//
// Nodes are created on demand. The mutating and iterating operations
// (setSupertype, setDescriptor, supertypes, immediateSubtypes, subtypes) add a
// node for any type they have not seen before, while the const queries
// (getSupertype, getDescriptor, getDescribed) report std::nullopt for types
// that have no node. Adding nodes or edges while iterating may invalidate the
// iterators below, since they refer into `nodes` and into the children
// vectors.
struct TypeTree {
  struct Node {
    // The type represented by this node.
    HeapType type;
    // The index of the parent (supertype) in the list of nodes. Set to the
    // index of this node if there is no parent.
    Index parent;
    // The index of this node in the parent's list of children, if any, enabling
    // O(1) updates.
    Index indexInParent = 0;
    // The indices of the children (subtypes) in the list of nodes.
    std::vector<Index> children;
    // The index of the described and descriptor types, if they are necessary.
    std::optional<Index> described;
    std::optional<Index> descriptor;

    // New nodes start out as roots, i.e. as their own parent.
    Node(HeapType type, Index index) : type(type), parent(index) {}
  };

  // The nodes of the forest. Nodes refer to each other by index into this
  // vector.
  std::vector<Node> nodes;
  // The index in `nodes` of each type that has a node.
  Map<HeapType, Index> indices;

  // Make `super` the immediate supertype of `sub`, first detaching `sub` (and
  // the subtree below it) from its previous supertype, if it had one.
  void setSupertype(HeapType sub, HeapType super) {
    // Both lookups may append to `nodes`, so finish them before taking any
    // references into the vector.
    auto childIndex = getIndex(sub);
    auto parentIndex = getIndex(super);
    auto& childNode = nodes[childIndex];
    auto& parentNode = nodes[parentIndex];
    // Remove sub from its old supertype if necessary.
    if (auto oldParentIndex = childNode.parent; oldParentIndex != childIndex) {
      auto& oldParentNode = nodes[oldParentIndex];
      // Move sub to the back of its parent's children and then pop it.
      auto& children = oldParentNode.children;
      assert(children[childNode.indexInParent] == childIndex);
      auto& swappedNode = nodes[children.back()];
      assert(swappedNode.indexInParent == children.size() - 1);
      // Swap the indices in the parent's child vector.
      std::swap(children[childNode.indexInParent], children.back());
      // Swap the index in the kept child.
      swappedNode.indexInParent = childNode.indexInParent;
      children.pop_back();
    }
    childNode.parent = parentIndex;
    childNode.indexInParent = parentNode.children.size();
    parentNode.children.push_back(childIndex);
  }

  // The recorded immediate supertype of `type`, or std::nullopt if it has none
  // or has no node.
  std::optional<HeapType> getSupertype(HeapType type) const {
    auto index = maybeGetIndex(type);
    if (!index) {
      return std::nullopt;
    }
    auto parentIndex = nodes[*index].parent;
    if (parentIndex == *index) {
      return std::nullopt;
    }
    return nodes[parentIndex].type;
  }

  // Record that `described` is described by `descriptor`. A type may be given
  // a descriptor at most once, and may be a descriptor at most once.
  void setDescriptor(HeapType described, HeapType descriptor) {
    auto describedIndex = getIndex(described);
    auto descriptorIndex = getIndex(descriptor);
    auto& describedNode = nodes[describedIndex];
    auto& descriptorNode = nodes[descriptorIndex];
    // We only ever set the descriptor once.
    assert(!describedNode.descriptor);
    assert(!descriptorNode.described);
    describedNode.descriptor = descriptorIndex;
    descriptorNode.described = describedIndex;
  }

  // The recorded descriptor of `type`, if any.
  std::optional<HeapType> getDescriptor(HeapType type) const {
    auto index = maybeGetIndex(type);
    if (!index) {
      return std::nullopt;
    }
    if (auto descIndex = nodes[*index].descriptor) {
      return nodes[*descIndex].type;
    }
    return std::nullopt;
  }

  // The recorded type described by `type`, if any.
  std::optional<HeapType> getDescribed(HeapType type) const {
    auto index = maybeGetIndex(type);
    if (!index) {
      return std::nullopt;
    }
    if (auto descIndex = nodes[*index].described) {
      return nodes[*descIndex].type;
    }
    return std::nullopt;
  }

  // Walks from a node up through its chain of recorded supertypes. The end
  // iterator has no index.
  struct SupertypeIterator {
    using value_type = const HeapType;
    using difference_type = std::ptrdiff_t;
    using reference = const HeapType&;
    using pointer = const HeapType*;
    using iterator_category = std::input_iterator_tag;

    TypeTree* parent;
    std::optional<Index> index;

    bool operator==(const SupertypeIterator& other) const {
      return index == other.index;
    }
    bool operator!=(const SupertypeIterator& other) const {
      return !(*this == other);
    }
    const HeapType& operator*() const { return parent->nodes[*index].type; }
    const HeapType* operator->() const { return &*(*this); }
    SupertypeIterator& operator++() {
      auto parentIndex = parent->nodes[*index].parent;
      if (parentIndex == *index) {
        index = std::nullopt;
      } else {
        index = parentIndex;
      }
      return *this;
    }
    SupertypeIterator operator++(int) {
      auto it = *this;
      ++(*this);
      return it;
    }
  };

  // Range over a type itself followed by each of its recorded supertypes,
  // nearest first.
  struct Supertypes {
    TypeTree* parent;
    Index index;
    SupertypeIterator begin() { return {parent, index}; }
    SupertypeIterator end() { return {parent, std::nullopt}; }
  };

  Supertypes supertypes(HeapType type) { return {this, getIndex(type)}; }

  // Iterates over the direct children of a single node.
  struct ImmediateSubtypeIterator {
    using value_type = const HeapType;
    using difference_type = std::ptrdiff_t;
    using reference = const HeapType&;
    using pointer = const HeapType*;
    using iterator_category = std::input_iterator_tag;

    TypeTree* parent;
    std::vector<Index>::const_iterator child;

    bool operator==(const ImmediateSubtypeIterator& other) const {
      return child == other.child;
    }
    bool operator!=(const ImmediateSubtypeIterator& other) const {
      return !(*this == other);
    }
    const HeapType& operator*() const { return parent->nodes[*child].type; }
    const HeapType* operator->() const { return &*(*this); }
    ImmediateSubtypeIterator& operator++() {
      ++child;
      return *this;
    }
    ImmediateSubtypeIterator operator++(int) {
      auto it = *this;
      ++(*this);
      return it;
    }
  };

  // Range over the recorded immediate subtypes of a type.
  struct ImmediateSubtypes {
    TypeTree* parent;
    Index index;
    ImmediateSubtypeIterator begin() {
      return {parent, parent->nodes[index].children.begin()};
    }
    ImmediateSubtypeIterator end() {
      return {parent, parent->nodes[index].children.end()};
    }
  };

  ImmediateSubtypes immediateSubtypes(HeapType type) {
    return {this, getIndex(type)};
  }

  // Pre-order depth-first traversal of a subtree. The end iterator has an
  // empty stack.
  struct SubtypeIterator {
    using value_type = const HeapType;
    using difference_type = std::ptrdiff_t;
    using reference = const HeapType&;
    using pointer = const HeapType*;
    using iterator_category = std::input_iterator_tag;

    TypeTree* parent;

    // DFS stack of (node index, child index) pairs.
    std::vector<std::pair<Index, Index>> stack;

    // Comparisons are const, like those of the other iterators. This lets
    // const iterators be compared and avoids C++20's ambiguity between a
    // non-const operator== and its rewritten reversed form.
    bool operator==(const SubtypeIterator& other) const {
      return stack == other.stack;
    }
    bool operator!=(const SubtypeIterator& other) const {
      return !(*this == other);
    }
    const HeapType& operator*() const {
      return parent->nodes[stack.back().first].type;
    }
    const HeapType* operator->() const { return &*(*this); }
    SubtypeIterator& operator++() {
      while (true) {
        if (stack.empty()) {
          return *this;
        }
        auto& [index, childIndex] = stack.back();
        auto& children = parent->nodes[index].children;
        if (childIndex == children.size()) {
          stack.pop_back();
        } else {
          // Read the child before push_back, which may invalidate the
          // structured-binding references into `stack`.
          auto child = children[childIndex++];
          stack.push_back({child, 0u});
          return *this;
        }
      }
    }
    SubtypeIterator operator++(int) {
      auto it = *this;
      ++(*this);
      return it;
    }
  };

  // Range over a type itself followed by all of its recorded (transitive)
  // subtypes, in depth-first pre-order.
  struct Subtypes {
    TypeTree* parent;
    Index index;
    SubtypeIterator begin() { return {parent, {std::make_pair(index, 0u)}}; }
    SubtypeIterator end() { return {parent, {}}; }
  };

  Subtypes subtypes(HeapType type) { return {this, getIndex(type)}; }

#if UNSUBTYPING_DEBUG
  void dump(Module& wasm) {
    for (auto& node : nodes) {
      std::cerr << ModuleHeapType(wasm, node.type);
      if (auto super = getSupertype(node.type)) {
        std::cerr << " <: " << ModuleHeapType(wasm, *super);
      }
      if (auto desc = getDescribed(node.type)) {
        std::cerr << ", describes " << ModuleHeapType(wasm, *desc);
      }
      if (auto desc = getDescriptor(node.type)) {
        std::cerr << ", descriptor " << ModuleHeapType(wasm, *desc);
      }
      std::cerr << ", children:";
      for (auto child : node.children) {
        std::cerr << " " << ModuleHeapType(wasm, nodes[child].type);
      }
      std::cerr << '\n';
    }
  }
#endif

private:
  // The index of `type`'s node, creating a new root node if it has none.
  Index getIndex(HeapType type) {
    auto [it, inserted] = indices.insert({type, nodes.size()});
    if (inserted) {
      nodes.emplace_back(type, nodes.size());
    }
    return it->second;
  }

  // The index of `type`'s node, or std::nullopt if it has none.
  std::optional<Index> maybeGetIndex(HeapType type) const {
    if (auto it = indices.find(type); it != indices.end()) {
      return it->second;
    }
    return std::nullopt;
  }
};

struct Unsubtyping : Pass {
  // The kind of work to process.
  enum class Kind { Subtype, Descriptor };
  // (kind, a, b) work items that we have discovered but not yet processed. For
  // Kind::Subtype, (a, b) is a (sub, super) pair; for Kind::Descriptor, it is a
  // (described, descriptor) pair. Items are processed last-in, first-out.
  std::vector<std::tuple<Kind, HeapType, HeapType>> work;

  // Record the type tree with supertype and subtype relations in such a way
  // that we can add new supertype relationships in constant time.
  TypeTree types;

  // Map from cast source types to their destinations, filled in by
  // analyzeModule.
  Map<HeapType, std::vector<HeapType>> casts;

  // The module being optimized; only present in debug builds, for printing.
  DBG(Module* wasm = nullptr);

  void run(Module* wasm) override {
    DBG(this->wasm = wasm);
    // Without GC there is nothing for this pass to do.
    if (!wasm->features.hasGC()) {
      return;
    }
    // The analysis assumes a closed world; refuse to run without it.
    if (!getPassOptions().closedWorld) {
      Fatal() << "Unsubtyping requires --closed-world";
    }

    // Seed the work list with the relations that public types and the module's
    // code require directly.
    analyzePublicTypes(*wasm);
    analyzeModule(*wasm);

    // Drain the work list (last in, first out). Processing an item may push
    // further items, so keep going until nothing new is discovered, i.e. until
    // we reach a fixed point.
    while (!work.empty()) {
      const auto [itemKind, first, second] = work.back();
      work.pop_back();
      if (itemKind == Kind::Subtype) {
        processSubtype(first, second);
      } else {
        processDescriptor(first, second);
      }
    }

    DBG(types.dump(*wasm));

    // Allocation sites of types that lost their descriptor may need updating,
    // so fix them up before the types themselves are rewritten.
    fixupAllocations(*wasm);
    rewriteTypes(*wasm);

    // Casts may now be refinable because their source and target types are no
    // longer related. TODO: Experiment with running this only after checking
    // whether it is necessary.
    ReFinalize().run(getPassRunner(), wasm);
  }

  // Queue the requirement that `sub` remain a subtype of `super`. Reflexive
  // requirements and requirements on bottom types are dropped; other basic
  // heap types are still queued because they can interact with casts.
  void noteSubtype(HeapType sub, HeapType super) {
    const bool uninteresting = sub.isBottom() || sub == super;
    if (uninteresting) {
      return;
    }
    DBG(std::cerr << "noting " << ModuleHeapType(*wasm, sub)
                  << " <: " << ModuleHeapType(*wasm, super) << '\n');
    work.emplace_back(Kind::Subtype, sub, super);
  }

  // Queue the heap type subtypings implied by requiring value type `sub` to be
  // a subtype of value type `super`. Tuples are related element-wise; pairs
  // that are not both references imply nothing.
  void noteSubtype(Type sub, Type super) {
    if (!sub.isTuple()) {
      if (sub.isRef() && super.isRef()) {
        noteSubtype(sub.getHeapType(), super.getHeapType());
      }
      return;
    }
    assert(super.isTuple() && sub.size() == super.size());
    for (size_t elem = 0; elem < sub.size(); ++elem) {
      noteSubtype(sub[elem], super[elem]);
    }
  }

  // Queue the requirement that `described` keep `descriptor` as its
  // descriptor. Unlike subtypings, these requirements are never filtered.
  void noteDescriptor(HeapType described, HeapType descriptor) {
    DBG(std::cerr << "noting " << ModuleHeapType(*wasm, described) << " -> "
                  << ModuleHeapType(*wasm, descriptor) << '\n');
    work.emplace_back(Kind::Descriptor, described, descriptor);
  }

  // The definitions of public types cannot be changed, so queue their declared
  // supertypes and descriptors as requirements.
  void analyzePublicTypes(Module& wasm) {
    auto publicTypes = ModuleUtils::getPublicHeapTypes(wasm);
    for (auto publicType : publicTypes) {
      auto declaredSuper = publicType.getDeclaredSuperType();
      if (declaredSuper) {
        noteSubtype(publicType, *declaredSuper);
      }
      auto declaredDesc = publicType.getDescriptorType();
      if (declaredDesc) {
        noteDescriptor(publicType, *declaredDesc);
      }
    }
  }

  // Walk every defined function (in parallel) and the module-level code to
  // collect the subtypings, casts, and descriptors that the IR requires
  // directly. Then seed the work list and `casts` with what was found.
  void analyzeModule(Module& wasm) {
    struct Info {
      // (source, target) pairs for casts.
      Set<std::pair<HeapType, HeapType>> casts;

      // Observed (sub, super) subtype constraints.
      Set<std::pair<HeapType, HeapType>> subtypings;

      // Observed (described, descriptor) requirements.
      Set<std::pair<HeapType, HeapType>> descriptors;
    };

    // Receives the note* callbacks from SubtypingDiscoverer as it walks the
    // IR, plus the descriptor-related visitors below, and records everything
    // in `info`. The callback and visitor names are the hooks the walker
    // framework calls, so they must keep exactly these names and signatures.
    struct Collector
      : ControlFlowWalker<Collector, SubtypingDiscoverer<Collector>> {
      using Super =
        ControlFlowWalker<Collector, SubtypingDiscoverer<Collector>>;

      Info& info;
      bool trapsNeverHappen;

      Collector(Info& info, bool trapsNeverHappen)
        : info(info), trapsNeverHappen(trapsNeverHappen) {}

      void noteSubtype(Type sub, Type super) {
        if (sub.isTuple()) {
          assert(super.isTuple() && sub.size() == super.size());
          for (size_t i = 0, size = sub.size(); i < size; ++i) {
            noteSubtype(sub[i], super[i]);
          }
          return;
        }
        if (!sub.isRef() || !super.isRef()) {
          return;
        }
        noteSubtype(sub.getHeapType(), super.getHeapType());
      }
      void noteSubtype(HeapType sub, HeapType super) {
        assert(HeapType::isSubType(sub, super));
        if (sub == super || sub.isBottom()) {
          return;
        }
        info.subtypings.insert({sub, super});
      }
      void noteSubtype(Type sub, Expression* super) {
        noteSubtype(sub, super->type);
      }
      void noteSubtype(Expression* sub, Type super) {
        noteSubtype(sub->type, super);
      }
      void noteSubtype(Expression* sub, Expression* super) {
        noteSubtype(sub->type, super->type);
      }
      void noteNonFlowSubtype(Expression* sub, Type super) {
        // This expression's type must be a subtype of |super|, but the value
        // does not flow anywhere - this is a static constraint. As the value
        // does not flow, it cannot reach anywhere else, which means we need
        // this in order to validate but it does not interact with casts. Given
        // that, if super is a basic type then we can simply ignore this: we
        // only remove subtyping between user types, so subtyping wrt basic
        // types is unchanged, and so this constraint will never be a problem.
        //
        // This is sort of a hack because in general to be precise we should not
        // just consider basic types here - in general, we should note for each
        // constraint whether it is a flow-based one or not, and only take the
        // flow-based ones into account when looking at the impact of casts.
        // However, in practice this is enough as the only non-trivial case of
        // |noteNonFlowSubtype| is for RefEq, which uses a basic type (eqref).
        // Other cases of non-flow subtyping end up trivial, e.g., the target of
        // a CallRef is compared to itself (and we ignore constraints of A :>
        // A). However, if we change how |noteNonFlowSubtype| is used in
        // SubtypingDiscoverer then we may need to generalize this.
        if (super.isRef() && super.getHeapType().isBasic()) {
          return;
        }

        // Otherwise, we must take this into account.
        noteSubtype(sub, super);
      }
      void noteCast(HeapType src, Type dstType) {
        auto dst = dstType.getHeapType();
        // Casts to self and casts that must fail because they have incompatible
        // types are uninteresting.
        if (dst == src) {
          return;
        }
        if (HeapType::isSubType(dst, src)) {
          if (dstType.isExact()) {
            // This cast only tests that the exact destination type is a subtype
            // of the source type and does not impose additional requirements on
            // subtypes of the destination type like a normal cast does.
            info.subtypings.insert({dst, src});
            return;
          }
          // A downcast. Record it so that processCasts can later preserve its
          // behavior for every type found to flow into it.
          info.casts.insert({src, dst});
          return;
        }
        if (HeapType::isSubType(src, dst)) {
          // This is an upcast that will always succeed, but only if we ensure
          // src <: dst.
          info.subtypings.insert({src, dst});
        }
      }
      void noteCast(Expression* src, Type dst) {
        if (src->type.isRef() && dst.isRef()) {
          noteCast(src->type.getHeapType(), dst);
        }
      }
      void noteCast(Expression* src, Expression* dst) {
        if (src->type.isRef() && dst->type.isRef()) {
          noteCast(src->type.getHeapType(), dst->type);
        }
      }

      // Visitors for finding required descriptors.
      // `type` is used as a described type, so its descriptor is required.
      void noteDescribed(HeapType type) {
        auto desc = type.getDescriptorType();
        assert(desc);
        info.descriptors.insert({type, *desc});
      }
      // `type` is used as a descriptor, so its relation to the type it
      // describes is required.
      void noteDescriptor(HeapType type) {
        auto desc = type.getDescribedType();
        assert(desc);
        info.descriptors.insert({*desc, type});
      }
      void visitRefGetDesc(RefGetDesc* curr) {
        Super::visitRefGetDesc(curr);
        if (!curr->ref->type.isStruct()) {
          return;
        }
        noteDescribed(curr->ref->type.getHeapType());
      }
      void visitRefCast(RefCast* curr) {
        Super::visitRefCast(curr);
        if (!curr->desc || !curr->desc->type.isStruct()) {
          return;
        }
        noteDescriptor(curr->desc->type.getHeapType());
      }
      void visitBrOn(BrOn* curr) {
        Super::visitBrOn(curr);
        if (!curr->desc || !curr->desc->type.isStruct()) {
          return;
        }
        noteDescriptor(curr->desc->type.getHeapType());
      }
      void visitStructNew(StructNew* curr) {
        Super::visitStructNew(curr);
        if (curr->type == Type::unreachable || !curr->desc) {
          return;
        }
        // Normally we do not treat struct.new as requiring a descriptor, even
        // if it has one. We are happy to optimize out descriptors that are set
        // in allocations and then never used. But if the descriptor is nullable
        // and outside a function context and we assume it may be null and cause
        // a trap, then we have no way to preserve that trap without keeping the
        // descriptor around.
        if (trapsNeverHappen || getFunction() ||
            curr->desc->type.isNonNullable()) {
          return;
        }
        // We must preserve the potential trap. When we update the instructions
        // later we will move this allocation to a new global if necessary to
        // preserve the potential trap even if a parent of the current
        // expression is removed.
        noteDescribed(curr->type.getHeapType());
      }
    };

    bool trapsNeverHappen = getPassOptions().trapsNeverHappen;

    // Collect subtyping constraints and casts from functions in parallel.
    ModuleUtils::ParallelFunctionAnalysis<Info> analysis(
      wasm, [&](Function* func, Info& info) {
        if (!func->imported()) {
          Collector(info, trapsNeverHappen).walkFunctionInModule(func, &wasm);
        }
      });

    // Merge the per-function results into a single Info.
    Info collectedInfo;
    for (auto& [_, info] : analysis.map) {
      collectedInfo.casts.insert(info.casts.begin(), info.casts.end());
      collectedInfo.subtypings.insert(info.subtypings.begin(),
                                      info.subtypings.end());
      collectedInfo.descriptors.insert(info.descriptors.begin(),
                                       info.descriptors.end());
    }

    // Collect constraints from module-level code as well.
    Collector collector(collectedInfo, trapsNeverHappen);
    collector.walkModuleCode(&wasm);
    // NOTE(review): presumably walkModuleCode does not leave the module set,
    // and the visitGlobal/visitElementSegment calls below need it - confirm.
    collector.setModule(&wasm);
    for (auto& global : wasm.globals) {
      collector.visitGlobal(global.get());
    }
    for (auto& segment : wasm.elementSegments) {
      collector.visitElementSegment(segment.get());
    }

    // Prepare the collected information for the upcoming processing loop.
    for (auto& [sub, super] : collectedInfo.subtypings) {
      noteSubtype(sub, super);
    }
    for (auto [src, dst] : collectedInfo.casts) {
      casts[src].push_back(dst);
    }
    for (auto [described, descriptor] : collectedInfo.descriptors) {
      noteDescriptor(described, descriptor);
    }
  }

  // Process a discovered requirement that `sub` remain a subtype of `super`.
  // Records the edge in `types`, reconciling it with any supertype already
  // recorded for `sub`. Then notes the further requirements the new edge
  // implies through descriptor squares, type definitions, and casts.
  void processSubtype(HeapType sub, HeapType super) {
    DBG(std::cerr << "processing " << ModuleHeapType(*wasm, sub)
                  << " <: " << ModuleHeapType(*wasm, super) << '\n');
    assert(HeapType::isSubType(sub, super));
    auto oldSuper = types.getSupertype(sub);
    if (oldSuper) {
      // We already had a recorded supertype. The new supertype might be
      // deeper, shallower, or equal to the old supertype. We must recursively
      // note the relationship between the old and new supertypes.
      if (super == *oldSuper) {
        // Nothing new to do here.
        return;
      }
      if (HeapType::isSubType(*oldSuper, super)) {
        // sub <: oldSuper <: super
        noteSubtype(*oldSuper, super);
        // We already handled sub <: oldSuper, so we're done.
        return;
      }
      // Both are supertypes of sub, so the code assumes that otherwise:
      // sub <: super <: oldSuper
      // Eagerly process super <: oldSuper first. This ensures that sub and
      // super will already be in the same tree when we process them below, so
      // when we process casts we will know that we only need to process up to
      // oldSuper.
      processSubtype(super, *oldSuper);
    }

    types.setSupertype(sub, super);

    // Complete the descriptor squares to the left and right of the new
    // subtyping edge if those squares can possibly exist based on the original
    // types.
    // If super is a descriptor, the new edge is the right side of a square.
    if (super.getDescribedType()) {
      completeDescriptorSquare(
        types.getDescribed(super), super, types.getDescribed(sub), sub);
    }
    // If super has a descriptor, the new edge is the left side of a square.
    if (super.getDescriptorType()) {
      completeDescriptorSquare(
        super, types.getDescriptor(super), sub, types.getDescriptor(sub));
    }

    // Find the implied subtypings from the type definitions and casts.
    processDefinitions(sub, super);
    processCasts(sub, super, oldSuper);
  }

  // Process a discovered requirement that `described` keep `descriptor` as its
  // descriptor. Records the edge in `types` the first time it is seen, then
  // completes any descriptor squares the new edge participates in.
  void processDescriptor(HeapType described, HeapType descriptor) {
    DBG(std::cerr << "processing " << ModuleHeapType(*wasm, described) << " -> "
                  << ModuleHeapType(*wasm, descriptor) << '\n');
    // Descriptor requirements always match the original type definitions.
    assert(described.getDescriptorType() &&
           *described.getDescriptorType() == descriptor);
    if (auto oldDesc = types.getDescriptor(described)) {
      // We already know about this descriptor.
      assert(*oldDesc == descriptor);
      return;
    }

    types.setDescriptor(described, descriptor);

    // Complete the descriptor squares above and below the new descriptor edge.
    // Above: a known supertype of `descriptor` is the top-right corner. The
    // top-left corner is passed as missing, so completeDescriptorSquare derives
    // it from the original described type of that supertype. NOTE(review):
    // presumably this is preferred over the currently recorded supertype of
    // `described`, which need not match.
    completeDescriptorSquare(
      std::nullopt, types.getSupertype(descriptor), described, descriptor);
    // Below: each known immediate subtype of `described` is a bottom-left
    // corner. Its descriptor is filled in if it is not yet known.
    for (auto sub : types.immediateSubtypes(described)) {
      completeDescriptorSquare(
        described, descriptor, sub, types.getDescriptor(sub));
    }
    // Below: each known immediate subtype of `descriptor` is a bottom-right
    // corner. The type it describes is filled in if it is not yet known.
    for (auto subDesc : types.immediateSubtypes(descriptor)) {
      completeDescriptorSquare(
        described, descriptor, types.getDescribed(subDesc), subDesc);
    }
  }

  // Queue the subtypings that must hold between the components of `sub` and
  // `super` for the definition of `sub` to remain valid under `super`. Basic
  // supertypes impose no such requirements.
  void processDefinitions(HeapType sub, HeapType super) {
    if (super.isBasic()) {
      return;
    }
    switch (sub.getKind()) {
      case HeapTypeKind::Func: {
        // Parameters are related contravariantly, results covariantly.
        auto subSig = sub.getSignature();
        auto superSig = super.getSignature();
        noteSubtype(superSig.params, subSig.params);
        noteSubtype(subSig.results, superSig.results);
        return;
      }
      case HeapTypeKind::Struct: {
        // Every field the supertype has must be related in the subtype; the
        // subtype may have extra trailing fields.
        const auto& subFields = sub.getStruct().fields;
        const auto& superFields = super.getStruct().fields;
        for (size_t field = 0; field < superFields.size(); ++field) {
          noteSubtype(subFields[field].type, superFields[field].type);
        }
        return;
      }
      case HeapTypeKind::Array:
        noteSubtype(sub.getArray().element.type,
                    super.getArray().element.type);
        return;
      case HeapTypeKind::Cont:
        WASM_UNREACHABLE("TODO: cont");
      case HeapTypeKind::Basic:
        WASM_UNREACHABLE("unexpected kind");
    }
  }

  // Queue the subtypings needed to preserve the behavior of casts after `sub`
  // has been placed under `super`. For every cast whose source is a newly
  // relevant supertype, each type in `sub`'s subtree that was a subtype of the
  // cast destination in the original module must remain one.
  void
  processCasts(HeapType sub, HeapType super, std::optional<HeapType> oldSuper) {
    // We are either attaching the one tree rooted at `sub` under a new
    // supertype in another tree, or we are reparenting `sub` below a
    // descendant of `oldSuper` in the same tree. In the former case, we must
    // evaluate `sub` and all its subtypes against all its new supertypes and
    // their cast destinations. In the latter case, `sub` and all its subtypes
    // must have already been evaluated against `oldSuper` and its supertypes,
    // so we only need to additionally evaluate them against supertypes up to
    // `oldSuper`.
    //
    // Gather the relevant cast destination lists once, in supertype order.
    // Using `find` instead of `casts[src]` avoids inserting empty entries for
    // supertypes that are never cast from. `casts` is not modified while this
    // function runs, so the pointers stay valid.
    std::vector<const std::vector<HeapType>*> dstLists;
    for (auto src : types.supertypes(super)) {
      if (oldSuper && src == *oldSuper) {
        break;
      }
      if (auto it = casts.find(src); it != casts.end() && !it->second.empty()) {
        dstLists.push_back(&it->second);
      }
    }
    if (dstLists.empty()) {
      // No casts can observe the new edge.
      return;
    }
    // Visit subtypes, then sources, then destinations, in the same order as a
    // nested walk of the supertype chain would.
    for (auto type : types.subtypes(sub)) {
      for (const auto* dsts : dstLists) {
        for (auto dst : *dsts) {
          if (HeapType::isSubType(type, dst)) {
            noteSubtype(type, dst);
          }
        }
      }
    }
  }

  // Given the four corners of a potential descriptor square,
  //
  //   super -> superDesc
  //     ^         ^
  //     |         |
  //    sub  ->  subDesc
  //
  // where std::nullopt marks a corner not (yet) known to be required, note
  // every edge of the square once at least three corners are present. A single
  // missing corner is filled in from the original descriptor relations. The
  // exception is a missing superDesc, which is allowed, so nothing is required
  // in that case.
  void completeDescriptorSquare(std::optional<HeapType> super,
                                std::optional<HeapType> superDesc,
                                std::optional<HeapType> sub,
                                std::optional<HeapType> subDesc) {
    if ((super && super->isBasic()) || (superDesc && superDesc->isBasic())) {
      // Basic types do not have descriptors or described types, so do not form
      // descriptor squares.
      return;
    }
    if (bool(super) + bool(superDesc) + bool(sub) + bool(subDesc) < 3) {
      // We must have two adjacent edges (involving at least 3 types) for there
      // to be any further requirements.
      return;
    }
    // There may be up to one missing type. Look it up using its original
    // descriptor relation with the present types and add the missing edges.
    if (!super) {
      super = superDesc->getDescribedType();
    } else if (!sub) {
      sub = subDesc->getDescribedType();
    } else if (!subDesc) {
      subDesc = sub->getDescriptorType();
    } else if (!superDesc) {
      // This is the only type that is allowed to be missing.
      return;
    }
    // Add all the edges. Don't worry about duplicating existing edges because
    // checking whether they're necessary now would be about as expensive as
    // discarding them later.
    // The sub <: super edge is only noted when it holds in the original types.
    // TODO: We will be able to assume this once we update the descriptor
    // validation rules.
    if (HeapType::isSubType(*sub, *super)) {
      noteSubtype(*sub, *super);
    }
    noteSubtype(*subDesc, *superDesc);
    noteDescriptor(*super, *superDesc);
    noteDescriptor(*sub, *subDesc);
  }

  // Rebuild the module's types so that each one declares only the supertype
  // and descriptor relations that survive in `types`.
  void rewriteTypes(Module& wasm) {
    struct Rewriter : GlobalTypeRewriter {
      Unsubtyping& parent;

      Rewriter(Unsubtyping& parent, Module& wasm)
        : GlobalTypeRewriter(wasm), parent(parent) {}

      // Declare the minimal supertype, if any. Basic (abstract) supertypes are
      // not declared in type definitions, so they map to "no supertype".
      std::optional<HeapType> getDeclaredSuperType(HeapType type) override {
        auto minimalSuper = parent.types.getSupertype(type);
        if (!minimalSuper || minimalSuper->isBasic()) {
          return std::nullopt;
        }
        return *minimalSuper;
      }

      // Clear any describes/descriptor clause the analysis did not keep.
      void modifyTypeBuilderEntry(TypeBuilder& typeBuilder,
                                  Index i,
                                  HeapType oldType) override {
        const bool dropDescribes = !parent.types.getDescribed(oldType);
        const bool dropDescriptor = !parent.types.getDescriptor(oldType);
        if (dropDescribes) {
          typeBuilder[i].describes(std::nullopt);
        }
        if (dropDescriptor) {
          typeBuilder[i].descriptor(std::nullopt);
        }
      }
    };

    Rewriter rewriter(*this, wasm);
    rewriter.update();
  }

  // Remove the descriptor operand from every struct.new whose type `types`
  // gives no descriptor, preserving the operand's side effects (in functions)
  // and potential instantiation-time traps (in module-level code).
  void fixupAllocations(Module& wasm) {
    // Descriptor operands can only exist when custom descriptors are enabled.
    if (!wasm.features.hasCustomDescriptors()) {
      return;
    }
    // TODO: Consider running the fixup only if we are actually removing any
    // descriptors. This would require a better way of detecting this than
    // collecting and iterating over all the types, though.
    struct Rewriter : WalkerPass<PostWalker<Rewriter>> {
      const TypeTree& types;

      // Allocations that might trap that have been removed from module-level
      // initializers. These need to be placed in new globals to preserve any
      // instantiation-time traps.
      std::vector<Expression*> removedTrappingInits;

      explicit Rewriter(const TypeTree& types) : types(types) {}

      bool isFunctionParallel() override { return true; }
      // Only introduces locals that are set immediately before they are used.
      bool requiresNonNullableLocalFixups() override { return false; }
      std::unique_ptr<Pass> create() override {
        return std::make_unique<Rewriter>(types);
      }

      void visitStructNew(StructNew* curr) {
        if (curr->type == Type::unreachable) {
          return;
        }
        if (!curr->desc) {
          return;
        }
        if (types.getDescriptor(curr->type.getHeapType())) {
          // The type keeps its descriptor, so the operand must stay.
          return;
        }
        // We need to drop the descriptor argument. In a function context, use
        // ChildLocalizer so that the children's side effects are kept. Outside
        // a function context (module-level code) the operand has no side
        // effects other than potential traps, which are handled below.
        if (auto* func = getFunction()) {
          // Preserve a trap from a null descriptor if necessary.
          if (!getPassOptions().trapsNeverHappen &&
              curr->desc->type.isNullable()) {
            curr->desc =
              Builder(*getModule()).makeRefAs(RefAsNonNull, curr->desc);
          }
          auto* block =
            ChildLocalizer(curr, func, *getModule(), getPassOptions())
              .getChildrenReplacement();
          block->list.push_back(curr);
          block->type = curr->type;
          replaceCurrent(block);
        } else {
          // We are dropping this descriptor, but it might have a potential trap
          // nested inside it. In that case we need to preserve the trap by
          // moving this descriptor to a new global.
          // NOTE(review): unlike the function branch above, a nullable
          // descriptor operand that is not a struct.new (e.g. a ref.null or a
          // global.get of a null global) is dropped here without preserving
          // the null-descriptor trap of `curr`; confirm the analysis keeps the
          // descriptor in that case or that this is intended.
          if (curr->desc->is<StructNew>() &&
              EffectAnalyzer(getPassOptions(), *getModule(), curr->desc).trap) {
            removedTrappingInits.push_back(curr->desc);
          }
        }
        curr->desc = nullptr;
      }
    };

    // Function bodies are handled in parallel by per-function copies; module
    // code runs on this instance, which is the only one that collects
    // removedTrappingInits.
    Rewriter rewriter(types);
    rewriter.run(getPassRunner(), &wasm);
    rewriter.runOnModuleCode(getPassRunner(), &wasm);

    // Insert globals necessary to preserve instantiation-time trapping of
    // removed allocations. They are appended after all existing globals, so
    // any globals the moved expressions read are already defined.
    for (Index i = 0; i < rewriter.removedTrappingInits.size(); ++i) {
      auto* curr = rewriter.removedTrappingInits[i];
      auto name = Names::getValidGlobalName(
        wasm, std::string("unsubtyping-removed-") + std::to_string(i));
      wasm.addGlobal(
        Builder::makeGlobal(name, curr->type, curr, Builder::Immutable));
    }
  }
};

} // anonymous namespace

// Create a fresh, value-initialized instance of the Unsubtyping pass. The
// returned pointer comes from `new`, so the caller is responsible for it.
Pass* createUnsubtypingPass() {
  Unsubtyping* pass = new Unsubtyping();
  return pass;
}

} // namespace wasm
